Attention

多头缩放点积注意力机制(Scaled Dot-Product Attention)

\[\text{Attention}(Q, K, V) = \operatorname{softmax}\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V\]
输入:
  • Q - 查询矩阵地址(行优先),形状 \([B, H, L, D]\) 展平。

  • K - 键矩阵地址(行优先),形状 \([B, H, L, D]\) 展平。

  • V - 值矩阵地址(行优先),形状 \([B, H, L, D]\) 展平。

  • params - 其余参数打包成数组。

  • core_mask(可选) - 核掩码(仅适用于共享存储版本)。

输出:
  • output - 输出地址(行优先),形状 \([B, H, L, D]\) 展平。

支持平台:

FT78NE MT7004

备注:

  • 当前实现基于 fp32;输入/中间/输出缓冲区不应重叠。

  • 内存布局为行优先(row-major)。

参数数组结构:

int params[10];
params[0] = batch_size;  // 批大小 B
params[1] = seq_len;     // 序列长度 L
params[2] = head_num;    // 多头数量 H
params[3] = head_dim;    // 每头通道维数 D
params[4] = QK 缓冲区地址的低 32 位;
params[5] = QK 缓冲区地址的高 32 位;
params[6] = 中间缓冲区地址的低 32 位;
params[7] = 中间缓冲区地址的高 32 位;

共享存储版本:

void fp_attention_s(float *Q, float *K, float *V, float *output, int *params, int core_mask)

C调用示例:

 1#include <stdio.h>
 2
 3int main(int argc, char* argv[]) {
 4    int B = 2, L = 128, H = 8, D = 64;
 5    float *Q = (float *)0xA0000000;      // DDR
 6    float *K = (float *)0xA1000000;      // DDR
 7    float *V = (float *)0xA2000000;      // DDR
 8    float *O = (float *)0xA3000000;      // DDR
 9    float *QK = (float *)0xA4000000;     // DDR
10    float *SM = (float *)0xA5000000;     // DDR
11    int core_mask = 0xff;
12    int params[10];
13    params[0] = B;
14    params[1] = L;
15    params[2] = H;
16    params[3] = D;
17    params[4] = (int) (uint32_t) (uintptr_t) QK;
18    params[5] = (int) (uint32_t) (uintptr_t) QK >> 32;
19    params[6] = (int) (uint32_t) (uintptr_t) SM;
20    params[7] = (int) (uint32_t) (uintptr_t) SM >> 32;
21    fp_attention_s(Q, K, V, O, params, core_mask);
22    return 0;
23}

私有存储版本:

void fp_attention_p(float *Q, float *K, float *V, float *output, int *params)

C调用示例:

 1#include <stdio.h>
 2
 3int main(int argc, char* argv[]) {
 4    int B = 1, L = 64, H = 4, D = 32;
 5    float *Q = (float *)0x10000000;   // L2
 6    float *K = (float *)0x10040000;   // L2
 7    float *V = (float *)0x10080000;   // L2
 8    float *O = (float *)0x100C0000;   // L2
 9    float *QK = (float *)0x10100000;  // L2
10    float *SM = (float *)0x10200000;  // L2
11    int params[10];
12    params[0] = B;
13    params[1] = L;
14    params[2] = H;
15    params[3] = D;
16    params[4] = (int) (uint32_t) (uintptr_t) QK;
17    params[5] = (int) (uint32_t) (uintptr_t) QK >> 32;
18    params[6] = (int) (uint32_t) (uintptr_t) SM;
19    params[7] = (int) (uint32_t) (uintptr_t) SM >> 32;
20    fp_attention_p(Q, K, V, O, params);
21    return 0;
22}